"""Logistic pivot fitting for fractal dimensions.

The box‑counting estimator produces a sequence of dimension values
``D_raw(n)`` at successive scales ``n``.  These values often follow a
sigmoidal trajectory as measurement resolution increases.  To
characterise this trajectory we fit a two‑parameter logistic model of
the form

    D(n) = 1 + 2 / (1 + exp[-k * (n - n0)])

where ``k`` controls the steepness of the transition and ``n0``
controls the midpoint (the pivot).  The lower asymptote is fixed at 1
and the upper at 3, reflecting the range of interest for 2–3D
fractal datasets.

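This module provides a function to estimate ``k`` and ``n0`` from a
sequence of observed dimensions using non‑linear least squares
(``scipy.optimize.curve_fit``).  When SciPy is not installed, or the
optimiser fails, a coarse grid search over ``k`` and ``n0`` is used
instead.  The function also computes the coefficient of determination
(R²) to quantify the goodness of fit.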
"""

from __future__ import annotations

import numpy as np
from typing import Tuple

try:
    from scipy.optimize import curve_fit
except ImportError:
    curve_fit = None  # type: ignore

def logistic_model(n: np.ndarray, k: float, n0: float) -> np.ndarray:
    """Two‑parameter logistic model with fixed lower/upper asymptotes.

    Evaluates ``D(n) = 1 + 2 / (1 + exp(-k * (n - n0)))``.  For ``k > 0``
    the curve rises from 1 to 3; for ``k < 0`` it falls from 3 to 1.  At
    the pivot ``n == n0`` the value is exactly 2.

    Parameters
    ----------
    n : np.ndarray
        Independent variable (scale indices).  NumPy broadcasting rules
        apply between ``n``, ``k`` and ``n0``.
    k : float
        Steepness of the logistic transition.
    n0 : float
        Midpoint (pivot) of the logistic curve.

    Returns
    -------
    np.ndarray
        Modelled dimension values at each input ``n``.
    """
    # Far below the pivot the exponent can exceed the float64 range.
    # exp() then yields inf, and 2 / (1 + inf) == 0, which is exactly
    # the lower asymptote.  The overflow is therefore harmless, so we
    # silence the RuntimeWarning that would otherwise spam optimiser runs.
    with np.errstate(over="ignore"):
        return 1.0 + 2.0 / (1.0 + np.exp(-k * (n - n0)))

def fit_logistic(n_vals: np.ndarray, D_vals: np.ndarray) -> Tuple[float, float, float, np.ndarray]:
    """Fit the logistic model to measured dimensions.

    ``k`` and ``n0`` are estimated with ``scipy.optimize.curve_fit``
    when SciPy is importable.  The starting point is ``k = 1`` (a rising
    curve) and ``n0`` at the median of the finite ``n`` values.  If SciPy
    is missing, or the optimiser fails to converge or rejects the data,
    the coarse grid search in ``_fallback_grid_search`` is used instead.

    Observations where either ``n`` or ``D`` is NaN or infinite are
    excluded from both the fit and the R² computation.

    Parameters
    ----------
    n_vals : np.ndarray
        Array of scale indices (integers).
    D_vals : np.ndarray
        Array of measured dimension estimates, same shape as ``n_vals``.

    Returns
    -------
    k : float
        Estimated steepness parameter.
    n0 : float
        Estimated pivot (midpoint) parameter.
    r2 : float
        Coefficient of determination over the finite observations.  It is
        NaN when those observations have zero variance.
    pred : np.ndarray
        Predicted dimension values from the fitted model, evaluated at
        every entry of ``n_vals`` (same shape as the input).

    Raises
    ------
    ValueError
        If ``n_vals`` and ``D_vals`` differ in shape, or if no finite
        ``(n, D)`` observation is available.
    """
    # Ensure float numpy arrays
    n_vals = np.asarray(n_vals, dtype=float)
    D_vals = np.asarray(D_vals, dtype=float)
    if n_vals.shape != D_vals.shape:
        # Without this check, differently shaped inputs could broadcast
        # silently into a meaningless residual matrix.
        raise ValueError(
            "n_vals and D_vals must have the same shape "
            f"(got {n_vals.shape} and {D_vals.shape})"
        )

    # Drop non-finite pairs.  curve_fit rejects NaN/inf outright, and a
    # sum of squared errors containing NaN can never be minimised.
    finite = np.isfinite(n_vals) & np.isfinite(D_vals)
    n_fit = n_vals[finite]
    D_fit = D_vals[finite]
    if n_fit.size == 0:
        raise ValueError("no finite (n, D) observations to fit")

    # Initial guesses: k positive, n0 near median of the usable n values
    p0 = [1.0, float(np.median(n_fit))]
    if curve_fit is None:
        k, n0 = _fallback_grid_search(n_fit, D_fit)
    else:
        try:
            params, _ = curve_fit(logistic_model, n_fit, D_fit, p0=p0, maxfev=10000)
        except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
            # Each exception type here is an optimiser failure:
            # - RuntimeError: no convergence within ``maxfev`` evaluations.
            # - ValueError: data rejected by the optimiser.
            # - TypeError: fewer data points than parameters.
            # - LinAlgError: numerical breakdown.
            # Fall back to the grid search rather than failing outright.
            k, n0 = _fallback_grid_search(n_fit, D_fit)
        else:
            k, n0 = params

    pred = logistic_model(n_vals, k, n0)
    # R² over the same finite observations that were used for fitting.
    ss_res = np.sum((D_fit - pred[finite]) ** 2)
    ss_tot = np.sum((D_fit - D_fit.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else np.nan
    return float(k), float(n0), float(r2), pred

def _fallback_grid_search(
    n_vals: np.ndarray,
    D_vals: np.ndarray,
    *,
    k_range: Tuple[float, float] = (0.1, 5.0),
    num: int = 40,
) -> Tuple[float, float]:
    """Fit ``k`` and ``n0`` by exhaustive search over a coarse grid.

    This is the optimiser used by ``fit_logistic`` when SciPy is
    unavailable or ``curve_fit`` fails.  It builds ``num`` evenly spaced
    ``k`` values across ``k_range`` and ``num`` evenly spaced pivots
    spanning ``[min(n_vals), max(n_vals)]``.  Every pair is scored by its
    sum of squared errors, and the lowest-scoring pair is returned.  The
    default ranges were chosen empirically for this problem.  Note that
    by default only ``k > 0`` (rising curves) is searched.

    All pairs are evaluated in a single broadcast, so memory use is about
    ``num * num * n_vals.size`` floats.

    Parameters
    ----------
    n_vals : np.ndarray
        Scale indices (flattened before use).
    D_vals : np.ndarray
        Observed dimensions, with the same number of elements as ``n_vals``.
    k_range : tuple of float, optional
        Inclusive ``(low, high)`` bounds for the ``k`` grid.
    num : int, optional
        Number of grid points along each axis (must be >= 1).

    Returns
    -------
    (k, n0) : tuple of float
        Best grid point.  Ties resolve to the first pair in ``k``-major
        order.  Pairs whose error is NaN never win.  If no pair has a
        non-NaN error, ``(k_range[0], min(n_vals))`` is returned.
    """
    n_flat = np.ravel(n_vals)
    D_flat = np.ravel(D_vals)
    ks = np.linspace(k_range[0], k_range[1], num)
    n0s = np.linspace(n_flat.min(), n_flat.max(), num)

    # Broadcast axes: (k index, n0 index, sample index).
    preds = logistic_model(
        n_flat[np.newaxis, np.newaxis, :],
        ks[:, np.newaxis, np.newaxis],
        n0s[np.newaxis, :, np.newaxis],
    )
    sse = np.sum((D_flat - preds) ** 2, axis=-1)
    # A NaN error must never be selected, so treat it as infinitely bad.
    # argmin then returns the first minimum in row-major (k-major) order,
    # which matches a scan that updates only on a strictly smaller error.
    sse = np.where(np.isnan(sse), np.inf, sse)
    i, j = np.unravel_index(np.argmin(sse), sse.shape)
    return ks[i], n0s[j]